{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "70300319-d206-43ce-b3bf-3da6b079f20f",
   "metadata": {
    "id": "70300319-d206-43ce-b3bf-3da6b079f20f"
   },
   "source": [
    "## MusicGen in MindNLP\n",
    "\n",
    "MusicGen is a Transformer-based model capable fo generating high-quality music samples conditioned on text descriptions or audio prompts. It was proposed in the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet et al. from Meta AI.\n",
    "\n",
    "The MusicGen model can be de-composed into three distinct stages:\n",
    "1. The text descriptions are passed through a frozen text encoder model to obtain a sequence of hidden-state representations\n",
    "2. The MusicGen decoder is then trained to predict discrete audio tokens, or *audio codes*, conditioned on these hidden-states\n",
    "3. These audio tokens are then decoded using an audio compression model, such as EnCodec, to recover the audio waveform\n",
    "\n",
    "The pre-trained MusicGen checkpoints use Google's [t5-base](https://huggingface.co/t5-base) as the text encoder model, and [EnCodec 32kHz](https://huggingface.co/facebook/encodec_32khz) as the audio compression model. The MusicGen decoder is a pure language model architecture,\n",
    "trained from scratch on the task of music generation.\n",
    "\n",
    "The novelty in the MusicGen model is how the audio codes are predicted. Traditionally, each codebook has to be predicted by a separate model (i.e. hierarchically) or by continuously refining the output of the Transformer model (i.e. upsampling). MusicGen uses an efficient *token interleaving pattern*, thus eliminating the need to cascade multiple models to predict a set of codebooks. Instead, it is able to generate the full set of codebooks in a single forward pass of the decoder, resulting in much faster inference.\n",
    "\n",
    "<p align=\"center\">\n",
    "  <img src=\"https://github.com/sanchit-gandhi/codesnippets/blob/main/delay_pattern.png?raw=true\" width=\"600\"/>\n",
    "</p>\n",
    "\n",
    "\n",
    "**Figure 1:** Codebook delay pattern used by MusicGen. Figure taken from the [MusicGen paper](https://arxiv.org/abs/2306.05284).\n"
   ]
  },
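  {
   "cell_type": "markdown",
   "id": "delay-pattern-sketch-md",
   "metadata": {},
   "source": [
    "The delay pattern in Figure 1 can be illustrated with a toy NumPy sketch. This is a minimal illustration, not the implementation used by MusicGen or MindNLP: the helper names `apply_delay_pattern` and `revert_delay_pattern` and the `pad_id` placeholder are hypothetical, and the code only assumes that codebook `k` is shifted right by `k` steps, so that one token per codebook can be predicted at every decoder position:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "delay-pattern-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def apply_delay_pattern(codes, pad_id=-1):\n",
    "    # codes has shape (num_codebooks, seq_len); pad_id is a hypothetical filler value\n",
    "    num_codebooks, seq_len = codes.shape\n",
    "    delayed = np.full((num_codebooks, seq_len + num_codebooks - 1), pad_id, dtype=codes.dtype)\n",
    "    for k in range(num_codebooks):\n",
    "        # shift codebook k right by k steps\n",
    "        delayed[k, k : k + seq_len] = codes[k]\n",
    "    return delayed\n",
    "\n",
    "\n",
    "def revert_delay_pattern(delayed, num_codebooks):\n",
    "    # undo the shift to recover the aligned (num_codebooks, seq_len) codes\n",
    "    seq_len = delayed.shape[1] - num_codebooks + 1\n",
    "    return np.stack([delayed[k, k : k + seq_len] for k in range(num_codebooks)])\n",
    "\n",
    "\n",
    "# toy example: 4 codebooks, 3 time steps\n",
    "codes = np.arange(12).reshape(4, 3)\n",
    "delayed = apply_delay_pattern(codes)\n",
    "print(delayed)\n",
    "\n",
    "# reverting the pattern recovers the original codes\n",
    "assert (revert_delay_pattern(delayed, 4) == codes).all()"
   ]
  },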
  {
   "cell_type": "markdown",
   "id": "77ee39cc-654b-4f0e-b601-013e484c16f0",
   "metadata": {
    "id": "77ee39cc-654b-4f0e-b601-013e484c16f0"
   },
   "source": [
    "## Load the Model\n",
    "\n",
    "The pre-trained MusicGen small, medium and large checkpoints can be loaded from the [pre-trained weights](https://huggingface.co/models?search=facebook/musicgen-) on the Hugging Face Hub. Change the repo id with the checkpoint size you wish to load. We'll default to the small checkpoint, which is the fastest of the three but has the lowest audio quality:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0d87424-9f38-4658-ba47-2a465d52ad77",
   "metadata": {
    "id": "b0d87424-9f38-4658-ba47-2a465d52ad77"
   },
   "outputs": [],
   "source": [
    "from mindnlp.transformers import MusicgenForConditionalGeneration\n",
    "\n",
    "model = MusicgenForConditionalGeneration.from_pretrained(\"facebook/musicgen-small\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6e1166e-1335-4555-9ec4-223d1fbcb547",
   "metadata": {
    "id": "f6e1166e-1335-4555-9ec4-223d1fbcb547"
   },
   "source": [
    "## Generation\n",
    "\n",
    "MusicGen is compatible with two generation modes: greedy and sampling. In practice, sampling leads to significantly\n",
    "better results than greedy, thus we encourage sampling mode to be used where possible. Sampling is enabled by default,\n",
    "and can be explicitly specified by setting `do_sample=True` in the call to `MusicgenForConditionalGeneration.generate` (see below).\n",
    "\n",
    "### Unconditional Generation\n",
    "\n",
    "The inputs for unconditional (or 'null') generation can be obtained through the method `MusicgenForConditionalGeneration.get_unconditional_inputs`. We can then run auto-regressive generation using the `.generate` method, specifying `do_sample=True` to enable sampling mode:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb7708e8-e4f1-4ab8-b04a-19395d78dea2",
   "metadata": {
    "id": "fb7708e8-e4f1-4ab8-b04a-19395d78dea2"
   },
   "outputs": [],
   "source": [
    "unconditional_inputs = model.get_unconditional_inputs(num_samples=1)\n",
    "\n",
    "audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94cb74df-c194-4d2e-930a-12473b08a919",
   "metadata": {
    "id": "94cb74df-c194-4d2e-930a-12473b08a919"
   },
   "source": [
    "The audio outputs are a three-dimensional Torch tensor of shape `(batch_size, num_channels, sequence_length)`. To listen\n",
    "to the generated audio samples, you can either play them in an ipynb notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15f0bc7c-b899-4e7a-943e-594e73f080ea",
   "metadata": {
    "id": "15f0bc7c-b899-4e7a-943e-594e73f080ea"
   },
   "outputs": [],
   "source": [
    "from IPython.display import Audio\n",
    "\n",
    "sampling_rate = model.config.audio_encoder.sampling_rate\n",
    "Audio(audio_values[0].asnumpy(), rate=sampling_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6de58334-40f7-4924-addb-2d6ff34c0590",
   "metadata": {
    "id": "6de58334-40f7-4924-addb-2d6ff34c0590"
   },
   "source": [
    "Or save them as a `.wav` file using a third-party library, e.g. `scipy` (note here that we also need to remove the channel dimension from our audio tensor):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04291f52-0a75-4ddb-9eff-e853d0f17288",
   "metadata": {
    "id": "04291f52-0a75-4ddb-9eff-e853d0f17288"
   },
   "outputs": [],
   "source": [
    "import scipy\n",
    "\n",
    "scipy.io.wavfile.write(\"musicgen_out.wav\", rate=sampling_rate, data=audio_values[0, 0].asnumpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e52ff5b2-c170-4079-93a4-a02acbdaeb39",
   "metadata": {
    "id": "e52ff5b2-c170-4079-93a4-a02acbdaeb39"
   },
   "source": [
    "The argument `max_new_tokens` specifies the number of new tokens to generate. As a rule of thumb, you can work out the length of the generated audio sample in seconds by using the frame rate of the EnCodec model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d75ad107-e19b-47f3-9cf1-5102ab4ae74a",
   "metadata": {
    "id": "d75ad107-e19b-47f3-9cf1-5102ab4ae74a"
   },
   "outputs": [],
   "source": [
    "audio_length_in_s = 256 / model.config.audio_encoder.frame_rate\n",
    "\n",
    "audio_length_in_s"
   ]
  },
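  {
   "cell_type": "markdown",
   "id": "tokens-for-duration-md",
   "metadata": {},
   "source": [
    "The rule of thumb above can also be inverted to pick `max_new_tokens` for a target duration. The following is a minimal sketch: `tokens_for_duration` is a hypothetical helper, and the default of 50 frames per second is an assumption implied by the default generation length of 1500 tokens for 30s of audio (see the Generation Config section). In this notebook you could pass `model.config.audio_encoder.frame_rate` instead:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "tokens-for-duration-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def tokens_for_duration(seconds, frame_rate=50):\n",
    "    # frame_rate=50 is an assumed value; round up to a whole number of frames\n",
    "    return math.ceil(seconds * frame_rate)\n",
    "\n",
    "\n",
    "print(tokens_for_duration(5))\n",
    "print(tokens_for_duration(30))"
   ]
  },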
  {
   "cell_type": "markdown",
   "id": "9a0e999b-2595-4090-8e1a-acfaa42d2581",
   "metadata": {
    "id": "9a0e999b-2595-4090-8e1a-acfaa42d2581"
   },
   "source": [
    "### Text-Conditional Generation\n",
    "\n",
    "The model can generate an audio sample conditioned on a text prompt through use of the `MusicgenProcessor` to pre-process\n",
    "the inputs. The pre-processed inputs can then be passed to the `.generate` method to generate text-conditional audio samples.\n",
    "Again, we enable sampling mode by setting `do_sample=True`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fba4154-13f6-403a-958b-101d6eacfb6e",
   "metadata": {
    "id": "5fba4154-13f6-403a-958b-101d6eacfb6e"
   },
   "outputs": [],
   "source": [
    "from mindnlp.transformers import AutoProcessor\n",
    "\n",
    "processor = AutoProcessor.from_pretrained(\"facebook/musicgen-small\")\n",
    "\n",
    "inputs = processor(\n",
    "    text=[\"80s pop track with bassy drums and synth\", \"90s rock song with loud guitars and heavy drums\"],\n",
    "    padding=True,\n",
    "    return_tensors=\"ms\",\n",
    ")\n",
    "\n",
    "audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)\n",
    "\n",
    "Audio(audio_values[0].asnumpy(), rate=sampling_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4851a94c-ae02-41c9-b1dd-c1422ba34dc0",
   "metadata": {
    "id": "4851a94c-ae02-41c9-b1dd-c1422ba34dc0"
   },
   "source": [
    "The `guidance_scale` is used in classifier free guidance (CFG), setting the weighting between the conditional logits\n",
    "(which are predicted from the text prompts) and the unconditional logits (which are predicted from an unconditional or\n",
    "'null' prompt). A higher guidance scale encourages the model to generate samples that are more closely linked to the input\n",
    "prompt, usually at the expense of poorer audio quality. CFG is enabled by setting `guidance_scale > 1`. For best results,\n",
    "use a `guidance_scale=3` (default) for text and audio-conditional generation."
   ]
  },
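  {
   "cell_type": "markdown",
   "id": "cfg-sketch-md",
   "metadata": {},
   "source": [
    "To make the weighting concrete, here is a small NumPy sketch of one common formulation of CFG, in which the unconditional logits are pushed towards the conditional logits by a factor of `guidance_scale`. The function `apply_cfg` and the toy logit values are hypothetical, and the exact implementation inside MindNLP may differ:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfg-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def apply_cfg(cond_logits, uncond_logits, guidance_scale):\n",
    "    # a scale of 1 returns the conditional logits unchanged;\n",
    "    # larger scales amplify the difference between conditional and unconditional logits\n",
    "    return uncond_logits + guidance_scale * (cond_logits - uncond_logits)\n",
    "\n",
    "\n",
    "# toy logits over a vocabulary of three audio codes (made-up values)\n",
    "cond = np.array([2.0, 0.5, -1.0])\n",
    "uncond = np.array([1.0, 1.0, 0.0])\n",
    "\n",
    "print(apply_cfg(cond, uncond, guidance_scale=1.0))\n",
    "print(apply_cfg(cond, uncond, guidance_scale=3.0))"
   ]
  },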
  {
   "cell_type": "markdown",
   "id": "d391b2a1-6376-4b69-b562-4388b731cf60",
   "metadata": {
    "id": "d391b2a1-6376-4b69-b562-4388b731cf60"
   },
   "source": [
    "### Audio-Prompted Generation\n",
    "\n",
    "The same `MusicgenProcessor` can be used to pre-process an audio prompt that is used for audio continuation. In the\n",
    "following example, we load an audio file using the 🤗 Datasets library, pre-process it using the processor class,\n",
    "and then forward the inputs to the model for generation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56a5c28a-f6c1-4ac8-ae08-6776a2b2c5b8",
   "metadata": {
    "id": "56a5c28a-f6c1-4ac8-ae08-6776a2b2c5b8"
   },
   "outputs": [],
   "source": [
    "from mindnlp.dataset import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"sanchit-gandhi/gtzan\", split=\"train\", streaming=True)\n",
    "sample = next(iter(dataset.create_dict_iterator(output_numpy=True)))[\"audio\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79fd7ab3-4d1f-4838-aff8-13d6fa568b3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# take the first half of the audio sample\n",
    "sample[\"array\"] = sample[\"array\"][: len(sample[\"array\"]) // 2]\n",
    "\n",
    "inputs = processor(\n",
    "    audio=sample[\"array\"],\n",
    "    sampling_rate=sample[\"sampling_rate\"],\n",
    "    text=[\"80s blues track with groovy saxophone\"],\n",
    "    padding=True,\n",
    "    return_tensors=\"pt\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3787d4e6-6d1c-479b-8c92-c8a58d176144",
   "metadata": {},
   "outputs": [],
   "source": [
    "audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)\n",
    "\n",
    "Audio(audio_values[0].asnumpy(), rate=sampling_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77518aa4-1b9b-4af6-b5ac-8ecdcb79b4cc",
   "metadata": {
    "id": "77518aa4-1b9b-4af6-b5ac-8ecdcb79b4cc"
   },
   "source": [
    "To demonstrate batched audio-prompted generation, we'll slice our sample audio by two different proportions to give two audio samples of different length.\n",
    "Since the input audio prompts vary in length, they will be *padded* to the length of the longest audio sample in the batch before being passed to the model.\n",
    "\n",
    "To recover the final audio samples, the `audio_values` generated can be post-processed to remove padding by using the processor class once again:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5495f568-51ca-439d-b47b-8b52e89b78f1",
   "metadata": {
    "id": "5495f568-51ca-439d-b47b-8b52e89b78f1"
   },
   "outputs": [],
   "source": [
    "sample = next(iter(dataset.create_dict_iterator(output_numpy=True)))[\"audio\"]\n",
    "\n",
    "# take the first quater of the audio sample\n",
    "sample_1 = sample[\"array\"][: len(sample[\"array\"]) // 4]\n",
    "\n",
    "# take the first half of the audio sample\n",
    "sample_2 = sample[\"array\"][: len(sample[\"array\"]) // 2]\n",
    "\n",
    "inputs = processor(\n",
    "    audio=[sample_1, sample_2],\n",
    "    sampling_rate=sample[\"sampling_rate\"],\n",
    "    text=[\"80s blues track with groovy saxophone\", \"90s rock song with loud guitars and heavy drums\"],\n",
    "    padding=True,\n",
    "    return_tensors=\"pt\",\n",
    ")\n",
    "\n",
    "audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)\n",
    "\n",
    "# post-process to remove padding from the batched audio\n",
    "audio_values = processor.batch_decode(audio_values, padding_mask=inputs.padding_mask)\n",
    "\n",
    "Audio(audio_values[0], rate=sampling_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "viwTDmzl8ZDN",
   "metadata": {
    "id": "viwTDmzl8ZDN"
   },
   "source": [
    "## Generation Config\n",
    "\n",
    "The default parameters that control the generation process, such as sampling, guidance scale and number of generated tokens, can be found in the model's generation config, and updated as desired. Let's first inspect the default generation config:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0zM4notb8Y1g",
   "metadata": {
    "id": "0zM4notb8Y1g"
   },
   "outputs": [],
   "source": [
    "model.generation_config"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "DLSnSwau8jyW",
   "metadata": {
    "id": "DLSnSwau8jyW"
   },
   "source": [
    "Alright! We see that the model defaults to using sampling mode (`do_sample=True`), a guidance scale of 3, and a maximum generation length of 1500 (which is equivalent to 30s of audio). You can update any of these attributes to change the default generation parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ensSj1IB81dA",
   "metadata": {
    "id": "ensSj1IB81dA"
   },
   "outputs": [],
   "source": [
    "# increase the guidance scale to 4.0\n",
    "model.generation_config.guidance_scale = 4.0\n",
    "\n",
    "# set the max new tokens to 256\n",
    "model.generation_config.max_new_tokens = 256\n",
    "\n",
    "# set the softmax sampling temperature to 1.5\n",
    "model.generation_config.temperature = 1.5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "UjqGnfc-9ZFJ",
   "metadata": {
    "id": "UjqGnfc-9ZFJ"
   },
   "source": [
    "Re-running generation now will use the newly defined values in the generation config:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "KAExrhDl9YvS",
   "metadata": {
    "id": "KAExrhDl9YvS"
   },
   "outputs": [],
   "source": [
    "audio_values = model.generate(**inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "HdGdoGAs84hS",
   "metadata": {
    "id": "HdGdoGAs84hS"
   },
   "source": [
    "Note that any arguments passed to the generate method will **supersede** those in the generation config, so setting `do_sample=False` in the call to generate will supersede the setting of `model.generation_config.do_sample` in the generation config."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
